{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2db540bc",
   "metadata": {},
   "source": [
    "# Fitting a 3D Bounding Box using a Differentiable Renderer\n",
    "\n",
    "Differentiable rendering enables optimization of 3D object properties like the geometry of a mesh. Unlike traditional rendering, differentiable rendering can backpropagate gradients from image space to 3D geometry.\n",
    "\n",
    "In this tutorial we fit a 3D bounding box around an object given ground truth images of the object. We assume images come with a segmentation mask of the object of interest and include camera extrinsics. During fitting we optimize a 3D bounding box to fit tightly around the object, despite never seeing any 3D ground truth data.\n",
    "\n",
    "We leverage kaolin's rendering capabilities with meshes, particularly the DIB-R rasterizer in `kaolin.render.mesh.dibr_rasterization` ([API documentation](https://kaolin.readthedocs.io/en/latest/modules/kaolin.render.mesh.html)). Please note this notebook is intended to illustrate key kaolin functionality and is not intended for use in a production system. A more sophisticated use of differentiable rendering can be found in the [Optimizing a mesh using a Differentiable Renderer notebook](examples/tutorial/dibr_tutorial.ipynb). Additional DIB-R examples can be found in [this repository](https://github.com/nv-tlabs/DIB-R-Single-Image-3D-Reconstruction).\n",
    "\n",
    "Before starting the tutorial uncompress the training data found in [examples/samples/rendered_clock.zip](examples/samples/rendered_clock.zip)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98082a4a-21da-4774-a5f5-02c17f8e102a",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "736eb3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import glob\n",
    "import os\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple\n",
    "\n",
    "import matplotlib.animation as animation\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import Tensor\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "try:\n",
    "    ipy_str = str(type(get_ipython()))\n",
    "    if 'zmqshell' in ipy_str:\n",
    "        %matplotlib notebook\n",
    "finally:\n",
    "    pass\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import kaolin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ce0ee9a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set hyperparameters\n",
    "batch_size_hyper = 2\n",
    "mask_weight_hyper = 1.0\n",
    "mask_occupancy_hyper = 0.05\n",
    "mask_overlap_hyper = 1.0\n",
    "lr = 5e-2\n",
    "scheduler_step_size = 5\n",
    "scheduler_gamma = 0.5\n",
    "num_epoch = 30"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06e8b4ee",
   "metadata": {},
   "source": [
    "# Preparing training data\n",
    "\n",
    "First we prepare the training 2D image data and differentiable 3D bounding box mesh.\n",
    "\n",
    "## Generating training data\n",
    "\n",
    "Our training data consists of segmentation masks of objects with known camera properties. One way to generate this data is to use the Data Generator in the [Omniverse Kaolin App](https://docs.omniverse.nvidia.com/app_kaolin/app_kaolin/user_manual.html#data-generator). We provide sample output from the app in `examples/samples/`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c69d6434",
   "metadata": {},
   "source": [
    "## Parsing synthetic data\n",
    "\n",
    "We first need to parse the synthetic data generated by the Omniverse app.\n",
    "The app produces one output file for each category of data - among depth map, RGB image, or segmentation map - along with an additional metadata JSON file. The JSON file includes the camera_properties, with data related to the camera settings including \"clipping_range\", \"horizontal_aperture\", \"focal_length\", \"tf_mat\".\n",
    "\n",
    "Below we parse this training data using the `kaolin.io.render.import_synthetic_view` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc9741fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set dataset of renders and the 3D bounding box mesh\n",
    "rendered_path = Path(\"../samples/rendered_clock/\")\n",
    "\n",
    "# prepare the 2D render dataset including camera extrinsics:\n",
    "num_views = len(glob.glob(os.path.join(rendered_path, \"*_rgb.png\")))\n",
    "train_data = []\n",
    "for i in range(num_views):\n",
    "    data = kaolin.io.render.import_synthetic_view(rendered_path, i, rgb=True, semantic=True)\n",
    "    train_data.append(data)\n",
    "dataloader = DataLoader(train_data, batch_size=batch_size_hyper, shuffle=True, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a238c89e",
   "metadata": {},
   "source": [
    "## Defining the 3D bounding box\n",
    "\n",
    "We will use a differentiable renderer to create 2D renders of a mesh, specifically for a 3D bounding box. In 3D bounding box estimation our goal is to fit the 3D box around an object while keeping the box as some cuboid. To enforce this constraint we will take a fixed 3D cuboid and apply transformations that rotate, translate, or scale the box. Note that this constrains learning by preventing and shearing or other transformations that would deform the cuboid shape.\n",
    "\n",
    "We create this model by loading a 3D cube mesh (provided for this tutorial in [examples/samples/bbox.obj](examples/samples/bbox.obj)) and defining a set of learnable parameters for the object rotation, translation, and scaling. When we access the vertices of the object we apply those transforms and return the modified cuboid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56f6020e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DifferentiableBBox:\n",
    "    \"\"\"Represents a differentiable 3D bounding box.\n",
    "\n",
    "    Box is parametrized in terms of a fixed mesh and a learned (optimized)\n",
    "    transformation in terms of rotation, scaling, and translation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Construct a 3D bounding from a mesh filepath.\n",
    "\n",
    "        Args:\n",
    "            mesh_file (str): Filepath to load the mesh. OBJ format.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        mesh_file = Path(\"../samples/bbox.obj\")\n",
    "        assert mesh_file.exists(), f\"File missing: {mesh_file.absolute()}.\"\n",
    "        \n",
    "\n",
    "        self._mesh = kaolin.io.obj.import_mesh(mesh_file, with_materials=True)\n",
    "        self._vertices = None\n",
    "        self._faces = None\n",
    "        \n",
    "        # define the learnable parameters\n",
    "        self._centers = torch.zeros((3,), dtype=torch.float, device=\"cuda\", requires_grad=True)\n",
    "        self._scales = torch.ones((3,), dtype=torch.float, device=\"cuda\", requires_grad=True)\n",
    "        self._rotations = torch.tensor([0.0, 0.0, 0.0, 1.0], device=\"cuda\", requires_grad=True)  # rotations as quaternion\n",
    "        \n",
    "        # prepare the vertices for learning\n",
    "        self._preprocess()\n",
    "\n",
    "    def _preprocess(self):\n",
    "        # scale up vertices to roughly the image scale\n",
    "        self._vertices = self._mesh.vertices.cuda().unsqueeze(0) * 0.75\n",
    "        \n",
    "        # disable gradient propagation to the vertices to fix the baseline 3D bounding box\n",
    "        self._vertices.requires_grad = False\n",
    "        \n",
    "        self._faces = self._mesh.faces.cuda()\n",
    "\n",
    "    @property\n",
    "    def vertices(self):\n",
    "        # convert the rotation quaternion to a 3x3 rotation matrix (see below)\n",
    "        rot = quaternion_to_matrix33(self._rotations)\n",
    "        # apply scaling, rotation, and translation to vertices\n",
    "        return (torch.matmul(self._vertices, rot) * self._scales) + self._centers\n",
    "\n",
    "    @property\n",
    "    def faces(self):\n",
    "        return self._faces\n",
    "    \n",
    "    @property\n",
    "    def parameters(self):\n",
    "        return [self._centers, self._scales, self._rotations]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20fc26be",
   "metadata": {},
   "source": [
    "### Mesh Rotation\n",
    "\n",
    "The mesh parameters represent the rotation as a quaternion. The image data represents points a 3D vectors, so for convenience we define a helper function that converts the quaternion rotation representation into a 3D rotation matrix. This simplifies the mesh rotation operation to a matrix multiplication.\n",
    "\n",
    "(As an alternative you could also modify the `vertices` property of `DifferentiableBBox` to rotate the vertices using the quaternion directly.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c3c198ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions to convert from quaternion to rotation matrix\n",
    "def vector_normalize(vec: Tensor) -> Tensor:\n",
    "    \"\"\"Normalize a 1d vector using the L2 norm.\n",
    "\n",
    "    Args:\n",
    "        vec (Tensor): A 1d vector of shape (b,1).\n",
    "\n",
    "    Returns:\n",
    "        Tensor: A normalized version of the input vector of shape (b,1).\n",
    "    \"\"\"    \n",
    "    return vec / vec.norm(p=2, dim=-1, keepdim=True)\n",
    "\n",
    "\n",
    "def quaternion_to_matrix33(quat: Tensor) -> Tensor:\n",
    "    \"\"\"Convert a quaternion to a 3x3 rotation matrix.\n",
    "\n",
    "    Args:\n",
    "        quat (Tensor): Rotation quaternion of shape (4).\n",
    "\n",
    "    Returns:\n",
    "        Tensor: Rotation matrix of shape (3,3).\n",
    "    \"\"\"    \n",
    "    # reference: http://www.euclideanspace.com/maths/geometry/rotations/conversions/quaternionToMatrix/index.htm\n",
    "    q = vector_normalize(quat)\n",
    "    \n",
    "    qx, qy, qz, qw = q[0], q[1], q[2], q[3]\n",
    "    sqw = qw ** 2\n",
    "    sqx = qx ** 2\n",
    "    sqy = qy ** 2\n",
    "    sqz = qz ** 2\n",
    "    qxy = qx * qy\n",
    "    qzw = qz * qw\n",
    "    qxz = qx * qz\n",
    "    qyw = qy * qw\n",
    "    qyz = qy * qz\n",
    "    qxw = qx * qw\n",
    "\n",
    "    invs = 1 / (sqx + sqy + sqz + sqw)\n",
    "    m00 = (sqx - sqy - sqz + sqw) * invs\n",
    "    m11 = (-sqx + sqy - sqz + sqw) * invs\n",
    "    m22 = (-sqx - sqy + sqz + sqw) * invs\n",
    "    m10 = 2.0 * (qxy + qzw) * invs\n",
    "    m01 = 2.0 * (qxy - qzw) * invs\n",
    "    m20 = 2.0 * (qxz - qyw) * invs\n",
    "    m02 = 2.0 * (qxz + qyw) * invs\n",
    "    m21 = 2.0 * (qyz + qxw) * invs\n",
    "    m12 = 2.0 * (qyz - qxw) * invs\n",
    "    r0 = torch.stack([m00, m01, m02])\n",
    "    r1 = torch.stack([m10, m11, m12])\n",
    "    r2 = torch.stack([m20, m21, m22])\n",
    "    mat33 = torch.stack([r0, r1, r2]).T\n",
    "    return mat33"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbcbb6ba",
   "metadata": {},
   "source": [
    "# Setting up the fitting loop\n",
    "\n",
    "In this section we set up the optimization and losses for training.\n",
    "\n",
    "## Setting up the optimizer\n",
    "\n",
    "Below we create the 3D bounding box and prepare an optimizer to tune the box's parameters. We add a learning rate scheduler so that we can gradually decrease the step size of changes made to the box parameters during the fitting process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "56085cce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up model & optimization parameters\n",
    "bbox = DifferentiableBBox()\n",
    "optim = torch.optim.Adam(params=bbox.parameters, lr=lr)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=scheduler_step_size, gamma=scheduler_gamma)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8979dd00",
   "metadata": {},
   "source": [
    "## Differentiable rendering\n",
    "\n",
    "During training we want to optimize the 3D bounding box to overlap with the semantic segmentation mask of the object in the 2D renders. We first need to project the 3D mesh into the 2D image space for each image. The function below uses ground truth data on the camera properties to project the 3D mesh vertices onto a 2D image to produce a 2D silhouette of the mesh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2b6bda5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_to_2d(\n",
    "    bbox: DifferentiableBBox,\n",
    "    batch_size: int,\n",
    "    image_shape: Tuple[int, int],\n",
    "    camera_transform: Tensor,\n",
    "    camera_projection: Tensor,\n",
    ") -> Tensor:\n",
    "    \"\"\"Render a mesh onto a 2D image given a viewing camera's transform and projection.\n",
    "\n",
    "    Args:\n",
    "        bbox (DifferentiableBBox): Differentiable bounding box representing 3D mesh.\n",
    "        batch_size (int): Number of elements in the data batch.\n",
    "        image_shape (Tuple[int, int]): Tuple of image dimensions (height, width).\n",
    "        camera_transform (Tensor): Camera transform of shape (b, 4, 3).\n",
    "        camera_projection (Tensor): Camera projection of shape (b, 3, 1).\n",
    "\n",
    "    Returns:\n",
    "        Tensor: Soft mask for mesh silhouette of shape (b, h, w). (h, w) given by `image_shape`.\n",
    "    \"\"\"\n",
    "    (face_vertices_camera, face_vertices_image, face_normals,) = kaolin.render.mesh.prepare_vertices(\n",
    "        bbox.vertices.repeat(batch_size, 1, 1),\n",
    "        bbox.faces,\n",
    "        camera_projection,\n",
    "        camera_transform=camera_transform,\n",
    "    )\n",
    "\n",
    "    nb_faces = bbox.faces.shape[0]\n",
    "    face_attributes = [torch.ones((batch_size, nb_faces, 3, 1), device=\"cuda\")]\n",
    "\n",
    "    image_features, soft_mask, face_idx = kaolin.render.mesh.dibr_rasterization(\n",
    "        image_shape[0],\n",
    "        image_shape[1],\n",
    "        face_vertices_camera[:, :, :, -1],\n",
    "        face_vertices_image,\n",
    "        face_attributes,\n",
    "        face_normals[:, :, -1],\n",
    "    )\n",
    "    return soft_mask  # only aligning images by 2D silhouette\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba4bfddb",
   "metadata": {},
   "source": [
    "## Defining losses\n",
    "\n",
    "Assuming that we have a 2D mask of the mesh in the image and the ground truth segmentation mask for that same image we can compare them. We define two losses to encourage the mesh to fit around the target object tightly:\n",
    "1. `overlap` is a loss that encourages the two masks to overlap as much as possible.\n",
    "2. `occupany` is a loss that encourages the 2D projection mask to be as large as possible.\n",
    "\n",
    "With overlap alone the projection will not stay outside the shape: it's better to cut off a few corners of the object in exchange for slightly less penalty from the empty space that's saved. Adding occupany counter-balances this effect. The end result is the mesh will try to fit around the outside of the shape (convex hull)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "75bc78d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def overlap(lhs_mask: Tensor, rhs_mask: Tensor) -> Tensor:\n",
    "    \"\"\"Compute the overlap of two 2D masks as the intersection over union.\n",
    "\n",
    "    Args:\n",
    "        lhs_mask (Tensor): 2D mask of shape (b, h, w).\n",
    "        rhs_mask (Tensor): 2D mask of shape (b, h, w).\n",
    "\n",
    "    Returns:\n",
    "        Tensor: Fraction of overlap of the two masks of shape (1). Averaged over batch samples.\n",
    "    \"\"\"    \n",
    "    batch_size, height, width = lhs_mask.shape\n",
    "    assert rhs_mask.shape == lhs_mask.shape\n",
    "    sil_mul = lhs_mask * rhs_mask\n",
    "    sil_area = torch.sum(sil_mul.reshape(batch_size, -1), dim=1)\n",
    "\n",
    "    return 1 - torch.mean(sil_area / (height * width))\n",
    "\n",
    "\n",
    "def occupancy(mask: Tensor) -> Tensor:\n",
    "    \"\"\"Compute what fraction of a total image is occupied by a 2D mask.\n",
    "\n",
    "    Args:\n",
    "        mask (Tensor): 2D mask of shape (b, h, w).\n",
    "\n",
    "    Returns:\n",
    "        Tensor: Fraction of the full image occupied by the mask of shape (1). Averaged over batch samples.\n",
    "    \"\"\"    \n",
    "    batch_size, height, width = mask.shape\n",
    "    mask_area = torch.sum(mask.reshape(batch_size, -1), dim=1)\n",
    "    return torch.mean(mask_area / (height * width))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e48917f",
   "metadata": {},
   "source": [
    "## Visualization utilities\n",
    "\n",
    "During training it's helpful to watch the progression of bounding box fitting. Below we define helper functions to plot images that show the bounding box silhouette compared to the ground truth silhouette."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eb5748e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_image(gt_mask: Tensor, pred_mask: Tensor) -> np.ndarray:\n",
    "    \"\"\"Compute an image array showing ground truth vs predicted masks.\n",
    "\n",
    "    Args:\n",
    "        gt_mask (Tensor): Ground truth mask of shape (w, h).\n",
    "        pred_mask (Tensor): Predicted mask of shape (w, h).\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: [0,1] normalized numpy array of shape (w, h, 3).\n",
    "    \"\"\"    \n",
    "    # mask shape is [w,h]\n",
    "    canvas = np.zeros((gt_mask.shape[0], gt_mask.shape[1], 3))\n",
    "    canvas[..., 2] = pred_mask.cpu().detach().numpy()\n",
    "    canvas[..., 1] = gt_mask.cpu().detach().numpy()\n",
    "\n",
    "    return np.clip(canvas, 0.0, 1.0)\n",
    "\n",
    "\n",
    "def show_renders(bbox: DifferentiableBBox, dataset: Tensor) -> List[np.ndarray]:\n",
    "    \"\"\"Generate images comparing ground truth and predicted semantic segmentations for a given mesh.\n",
    "\n",
    "    The mesh is projected to 2D to match the camera views of each ground truth 2D render.\n",
    "    Images plot the true silhouette of the object overlayed with the mesh silhouette.\n",
    "\n",
    "    Args:\n",
    "        bbox (DifferentiableBBox): Differentiable bounding box representing 3D mesh.\n",
    "        dataset (Tensor): Batch of ground truth 2D renders with camera extrinsics.\n",
    "\n",
    "    Returns:\n",
    "        List[np.ndarray]: [0,1] normalized images comparing ground truth and mesh silhouette.\n",
    "    \"\"\"    \n",
    "    with torch.no_grad():\n",
    "        images = []\n",
    "        for sample in dataset:\n",
    "            gt_mask = sample[\"semantic\"].cuda()\n",
    "            camera_transform = sample[\"metadata\"][\"cam_transform\"].cuda()\n",
    "            camera_projection = sample[\"metadata\"][\"cam_proj\"].cuda()\n",
    "\n",
    "            # project model mesh onto 2D image\n",
    "            img_shape = (gt_mask.shape[0], gt_mask.shape[1])\n",
    "            soft_mask = project_to_2d(\n",
    "                bbox,\n",
    "                batch_size=1,\n",
    "                image_shape=img_shape,\n",
    "                camera_transform=camera_transform,\n",
    "                camera_projection=camera_projection,\n",
    "            )\n",
    "            image = draw_image(gt_mask.squeeze(), soft_mask)\n",
    "            images.append(image)\n",
    "\n",
    "        return images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "864bc6b5",
   "metadata": {},
   "source": [
    "# Fitting\n",
    "\n",
    "The fitting loop below pulls everything together. Each iteration samples a batch of ground truth 2D renders and their associated camera data. The mesh is projected to the same 2D images based on the camera properties and the projected mesh silhouette is compared to the ground truth silhouette (segmentation mask). Losses are calculated by comparing the two silhouettes and a gradient step backpropagates the error through the mesh parameters (that is, the bounding box rotation, translation, and scaling from the initial mesh position)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "345eca8a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "/* global mpl */\n",
       "window.mpl = {};\n",
       "\n",
       "mpl.get_websocket_type = function () {\n",
       "    if (typeof WebSocket !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof MozWebSocket !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert(\n",
       "            'Your browser does not have WebSocket support. ' +\n",
       "                'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "                'Firefox 4 and 5 are also supported but you ' +\n",
       "                'have to enable WebSockets in about:config.'\n",
       "        );\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = this.ws.binaryType !== undefined;\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById('mpl-warnings');\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent =\n",
       "                'This browser does not support binary websocket messages. ' +\n",
       "                'Performance may be slow.';\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = document.createElement('div');\n",
       "    this.root.setAttribute('style', 'display: inline-block');\n",
       "    this._root_extra_style(this.root);\n",
       "\n",
       "    parent_element.appendChild(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen = function () {\n",
       "        fig.send_message('supports_binary', { value: fig.supports_binary });\n",
       "        fig.send_message('send_image_mode', {});\n",
       "        if (fig.ratio !== 1) {\n",
       "            fig.send_message('set_device_pixel_ratio', {\n",
       "                device_pixel_ratio: fig.ratio,\n",
       "            });\n",
       "        }\n",
       "        fig.send_message('refresh', {});\n",
       "    };\n",
       "\n",
       "    this.imageObj.onload = function () {\n",
       "        if (fig.image_mode === 'full') {\n",
       "            // Full images could contain transparency (where diff images\n",
       "            // almost always do), so we need to clear the canvas so that\n",
       "            // there is no ghosting.\n",
       "            fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "        }\n",
       "        fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "    };\n",
       "\n",
       "    this.imageObj.onunload = function () {\n",
       "        fig.ws.close();\n",
       "    };\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_header = function () {\n",
       "    var titlebar = document.createElement('div');\n",
       "    titlebar.classList =\n",
       "        'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n",
       "    var titletext = document.createElement('div');\n",
       "    titletext.classList = 'ui-dialog-title';\n",
       "    titletext.setAttribute(\n",
       "        'style',\n",
       "        'width: 100%; text-align: center; padding: 3px;'\n",
       "    );\n",
       "    titlebar.appendChild(titletext);\n",
       "    this.root.appendChild(titlebar);\n",
       "    this.header = titletext;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = (this.canvas_div = document.createElement('div'));\n",
       "    canvas_div.setAttribute(\n",
       "        'style',\n",
       "        'border: 1px solid #ddd;' +\n",
       "            'box-sizing: content-box;' +\n",
       "            'clear: both;' +\n",
       "            'min-height: 1px;' +\n",
       "            'min-width: 1px;' +\n",
       "            'outline: 0;' +\n",
       "            'overflow: hidden;' +\n",
       "            'position: relative;' +\n",
       "            'resize: both;'\n",
       "    );\n",
       "\n",
       "    function on_keyboard_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.key_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    canvas_div.addEventListener(\n",
       "        'keydown',\n",
       "        on_keyboard_event_closure('key_press')\n",
       "    );\n",
       "    canvas_div.addEventListener(\n",
       "        'keyup',\n",
       "        on_keyboard_event_closure('key_release')\n",
       "    );\n",
       "\n",
       "    this._canvas_extra_style(canvas_div);\n",
       "    this.root.appendChild(canvas_div);\n",
       "\n",
       "    var canvas = (this.canvas = document.createElement('canvas'));\n",
       "    canvas.classList.add('mpl-canvas');\n",
       "    canvas.setAttribute('style', 'box-sizing: content-box;');\n",
       "\n",
       "    this.context = canvas.getContext('2d');\n",
       "\n",
       "    var backingStore =\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        this.context.webkitBackingStorePixelRatio ||\n",
       "        this.context.mozBackingStorePixelRatio ||\n",
       "        this.context.msBackingStorePixelRatio ||\n",
       "        this.context.oBackingStorePixelRatio ||\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        1;\n",
       "\n",
       "    this.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n",
       "        'canvas'\n",
       "    ));\n",
       "    rubberband_canvas.setAttribute(\n",
       "        'style',\n",
       "        'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n",
       "    );\n",
       "\n",
       "    // Apply a ponyfill if ResizeObserver is not implemented by browser.\n",
       "    if (this.ResizeObserver === undefined) {\n",
       "        if (window.ResizeObserver !== undefined) {\n",
       "            this.ResizeObserver = window.ResizeObserver;\n",
       "        } else {\n",
       "            var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n",
       "            this.ResizeObserver = obs.ResizeObserver;\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n",
       "        var nentries = entries.length;\n",
       "        for (var i = 0; i < nentries; i++) {\n",
       "            var entry = entries[i];\n",
       "            var width, height;\n",
       "            if (entry.contentBoxSize) {\n",
       "                if (entry.contentBoxSize instanceof Array) {\n",
       "                    // Chrome 84 implements new version of spec.\n",
       "                    width = entry.contentBoxSize[0].inlineSize;\n",
       "                    height = entry.contentBoxSize[0].blockSize;\n",
       "                } else {\n",
       "                    // Firefox implements old version of spec.\n",
       "                    width = entry.contentBoxSize.inlineSize;\n",
       "                    height = entry.contentBoxSize.blockSize;\n",
       "                }\n",
       "            } else {\n",
       "                // Chrome <84 implements even older version of spec.\n",
       "                width = entry.contentRect.width;\n",
       "                height = entry.contentRect.height;\n",
       "            }\n",
       "\n",
       "            // Keep the size of the canvas and rubber band canvas in sync with\n",
       "            // the canvas container.\n",
       "            if (entry.devicePixelContentBoxSize) {\n",
       "                // Chrome 84 implements new version of spec.\n",
       "                canvas.setAttribute(\n",
       "                    'width',\n",
       "                    entry.devicePixelContentBoxSize[0].inlineSize\n",
       "                );\n",
       "                canvas.setAttribute(\n",
       "                    'height',\n",
       "                    entry.devicePixelContentBoxSize[0].blockSize\n",
       "                );\n",
       "            } else {\n",
       "                canvas.setAttribute('width', width * fig.ratio);\n",
       "                canvas.setAttribute('height', height * fig.ratio);\n",
       "            }\n",
       "            canvas.setAttribute(\n",
       "                'style',\n",
       "                'width: ' + width + 'px; height: ' + height + 'px;'\n",
       "            );\n",
       "\n",
       "            rubberband_canvas.setAttribute('width', width);\n",
       "            rubberband_canvas.setAttribute('height', height);\n",
       "\n",
       "            // And update the size in Python. We ignore the initial 0/0 size\n",
       "            // that occurs as the element is placed into the DOM, which should\n",
       "            // otherwise not happen due to the minimum size styling.\n",
       "            if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n",
       "                fig.request_resize(width, height);\n",
       "            }\n",
       "        }\n",
       "    });\n",
       "    this.resizeObserverInstance.observe(canvas_div);\n",
       "\n",
       "    function on_mouse_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.mouse_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousedown',\n",
       "        on_mouse_event_closure('button_press')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseup',\n",
       "        on_mouse_event_closure('button_release')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'dblclick',\n",
       "        on_mouse_event_closure('dblclick')\n",
       "    );\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousemove',\n",
       "        on_mouse_event_closure('motion_notify')\n",
       "    );\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseenter',\n",
       "        on_mouse_event_closure('figure_enter')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseleave',\n",
       "        on_mouse_event_closure('figure_leave')\n",
       "    );\n",
       "\n",
       "    canvas_div.addEventListener('wheel', function (event) {\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        on_mouse_event_closure('scroll')(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.appendChild(canvas);\n",
       "    canvas_div.appendChild(rubberband_canvas);\n",
       "\n",
       "    this.rubberband_context = rubberband_canvas.getContext('2d');\n",
       "    this.rubberband_context.strokeStyle = '#000000';\n",
       "\n",
       "    this._resize_canvas = function (width, height, forward) {\n",
       "        if (forward) {\n",
       "            canvas_div.style.width = width + 'px';\n",
       "            canvas_div.style.height = height + 'px';\n",
       "        }\n",
       "    };\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n",
       "        event.preventDefault();\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus() {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'mpl-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'mpl-button-group';\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'mpl-button-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        var button = (fig.buttons[name] = document.createElement('button'));\n",
       "        button.classList = 'mpl-widget';\n",
       "        button.setAttribute('role', 'button');\n",
       "        button.setAttribute('aria-disabled', 'false');\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "\n",
       "        var icon_img = document.createElement('img');\n",
       "        icon_img.src = '_images/' + image + '.png';\n",
       "        icon_img.srcset = '_images/' + image + '_large.png 2x';\n",
       "        icon_img.alt = tooltip;\n",
       "        button.appendChild(icon_img);\n",
       "\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    var fmt_picker = document.createElement('select');\n",
       "    fmt_picker.classList = 'mpl-widget';\n",
       "    toolbar.appendChild(fmt_picker);\n",
       "    this.format_dropdown = fmt_picker;\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = document.createElement('option');\n",
       "        option.selected = fmt === mpl.default_extension;\n",
       "        option.innerHTML = fmt;\n",
       "        fmt_picker.appendChild(option);\n",
       "    }\n",
       "\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', { width: x_pixels, height: y_pixels });\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_message = function (type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function () {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function (fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1], msg['forward']);\n",
       "        fig.send_message('refresh', {});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n",
       "    var x0 = msg['x0'] / fig.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n",
       "    var x1 = msg['x1'] / fig.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0,\n",
       "        0,\n",
       "        fig.canvas.width / fig.ratio,\n",
       "        fig.canvas.height / fig.ratio\n",
       "    );\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function (fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n",
       "    fig.rubberband_canvas.style.cursor = msg['cursor'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_message = function (fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n",
       "    for (var key in msg) {\n",
       "        if (!(key in fig.buttons)) {\n",
       "            continue;\n",
       "        }\n",
       "        fig.buttons[key].disabled = !msg[key];\n",
       "        fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n",
       "    if (msg['mode'] === 'PAN') {\n",
       "        fig.buttons['Pan'].classList.add('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    } else if (msg['mode'] === 'ZOOM') {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.add('active');\n",
       "    } else {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message('ack', {});\n",
       "};\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function (fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            var img = evt.data;\n",
       "            if (img.type !== 'image/png') {\n",
       "                /* FIXME: We get \"Resource interpreted as Image but\n",
       "                 * transferred with MIME type text/plain:\" errors on\n",
       "                 * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "                 * to be part of the websocket stream */\n",
       "                img.type = 'image/png';\n",
       "            }\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src\n",
       "                );\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                img\n",
       "            );\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        } else if (\n",
       "            typeof evt.data === 'string' &&\n",
       "            evt.data.slice(0, 21) === 'data:image/png;base64'\n",
       "        ) {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig['handle_' + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\n",
       "                \"No handler for the '\" + msg_type + \"' message type: \",\n",
       "                msg\n",
       "            );\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\n",
       "                    \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n",
       "                    e,\n",
       "                    e.stack,\n",
       "                    msg\n",
       "                );\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "};\n",
       "\n",
       "// from https://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function (e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e) {\n",
       "        e = window.event;\n",
       "    }\n",
       "    if (e.target) {\n",
       "        targ = e.target;\n",
       "    } else if (e.srcElement) {\n",
       "        targ = e.srcElement;\n",
       "    }\n",
       "    if (targ.nodeType === 3) {\n",
       "        // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "    }\n",
       "\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    var boundingRect = targ.getBoundingClientRect();\n",
       "    var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n",
       "    var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n",
       "\n",
       "    return { x: x, y: y };\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * https://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys(original) {\n",
       "    return Object.keys(original).reduce(function (obj, key) {\n",
       "        if (typeof original[key] !== 'object') {\n",
       "            obj[key] = original[key];\n",
       "        }\n",
       "        return obj;\n",
       "    }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function (event, name) {\n",
       "    var canvas_pos = mpl.findpos(event);\n",
       "\n",
       "    if (name === 'button_press') {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * this.ratio;\n",
       "    var y = canvas_pos.y * this.ratio;\n",
       "\n",
       "    this.send_message(name, {\n",
       "        x: x,\n",
       "        y: y,\n",
       "        button: event.button,\n",
       "        step: event.step,\n",
       "        guiEvent: simpleKeys(event),\n",
       "    });\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.key_event = function (event, name) {\n",
       "    // Prevent repeat events\n",
       "    if (name === 'key_press') {\n",
       "        if (event.key === this._key) {\n",
       "            return;\n",
       "        } else {\n",
       "            this._key = event.key;\n",
       "        }\n",
       "    }\n",
       "    if (name === 'key_release') {\n",
       "        this._key = null;\n",
       "    }\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.key !== 'Control') {\n",
       "        value += 'ctrl+';\n",
       "    }\n",
       "    else if (event.altKey && event.key !== 'Alt') {\n",
       "        value += 'alt+';\n",
       "    }\n",
       "    else if (event.shiftKey && event.key !== 'Shift') {\n",
       "        value += 'shift+';\n",
       "    }\n",
       "\n",
       "    value += 'k' + event.key;\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n",
       "    if (name === 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message('toolbar_button', { name: name });\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "\n",
       "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n",
       "// prettier-ignore\n",
       "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // 
eslint-disable-line\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pgf\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";/* global mpl */\n",
       "\n",
       "var comm_websocket_adapter = function (comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.binaryType = comm.kernel.ws.binaryType;\n",
       "    ws.readyState = comm.kernel.ws.readyState;\n",
       "    function updateReadyState(_event) {\n",
       "        if (comm.kernel.ws) {\n",
       "            ws.readyState = comm.kernel.ws.readyState;\n",
       "        } else {\n",
       "            ws.readyState = 3; // Closed state.\n",
       "        }\n",
       "    }\n",
       "    comm.kernel.ws.addEventListener('open', updateReadyState);\n",
       "    comm.kernel.ws.addEventListener('close', updateReadyState);\n",
       "    comm.kernel.ws.addEventListener('error', updateReadyState);\n",
       "\n",
       "    ws.close = function () {\n",
       "        comm.close();\n",
       "    };\n",
       "    ws.send = function (m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function (msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        var data = msg['content']['data'];\n",
       "        if (data['blob'] !== undefined) {\n",
       "            data = {\n",
       "                data: new Blob(msg['buffers'], { type: data['blob'] }),\n",
       "            };\n",
       "        }\n",
       "        // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
       "        ws.onmessage(data);\n",
       "    });\n",
       "    return ws;\n",
       "};\n",
       "\n",
       "mpl.mpl_figure_comm = function (comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = document.getElementById(id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm);\n",
       "\n",
       "    function ondownload(figure, _format) {\n",
       "        window.open(figure.canvas.toDataURL());\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element;\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error('Failed to find cell for figure', id, fig);\n",
       "        return;\n",
       "    }\n",
       "    fig.cell_info[0].output_area.element.on(\n",
       "        'cleared',\n",
       "        { fig: fig },\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function (fig, msg) {\n",
       "    var width = fig.canvas.width / fig.ratio;\n",
       "    fig.cell_info[0].output_area.element.off(\n",
       "        'cleared',\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "    fig.resizeObserverInstance.unobserve(fig.canvas_div);\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable();\n",
       "    fig.parent_element.innerHTML =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "    fig.close_ws(fig, msg);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.close_ws = function (fig, msg) {\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width / this.ratio;\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message('ack', {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () {\n",
       "        fig.push_to_output();\n",
       "    }, 1000);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'btn-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'btn-group';\n",
       "    var button;\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'btn-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        button = fig.buttons[name] = document.createElement('button');\n",
       "        button.classList = 'btn btn-default';\n",
       "        button.href = '#';\n",
       "        button.title = name;\n",
       "        button.innerHTML = '<i class=\"fa ' + image + ' fa-lg\"></i>';\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message pull-right';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = document.createElement('div');\n",
       "    buttongrp.classList = 'btn-group inline pull-right';\n",
       "    button = document.createElement('button');\n",
       "    button.classList = 'btn btn-mini btn-primary';\n",
       "    button.href = '#';\n",
       "    button.title = 'Stop Interaction';\n",
       "    button.innerHTML = '<i class=\"fa fa-power-off icon-remove icon-large\"></i>';\n",
       "    button.addEventListener('click', function (_evt) {\n",
       "        fig.handle_close(fig, {});\n",
       "    });\n",
       "    button.addEventListener(\n",
       "        'mouseover',\n",
       "        on_mouseover_closure('Stop Interaction')\n",
       "    );\n",
       "    buttongrp.appendChild(button);\n",
       "    var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n",
       "    titlebar.insertBefore(buttongrp, titlebar.firstChild);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._remove_fig_handler = function (event) {\n",
       "    var fig = event.data.fig;\n",
       "    if (event.target !== this) {\n",
       "        // Ignore bubbled events from children.\n",
       "        return;\n",
       "    }\n",
       "    fig.close_ws(fig, {});\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (el) {\n",
       "    el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (el) {\n",
       "    // this is important to make the div 'focusable\n",
       "    el.setAttribute('tabindex', 0);\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    } else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (event, _name) {\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which === 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "};\n",
       "\n",
       "mpl.find_output_cell = function (html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i = 0; i < ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code') {\n",
       "            for (var j = 0; j < cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] === html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "};\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel !== null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target(\n",
       "        'matplotlib',\n",
       "        mpl.mpl_figure_comm\n",
       "    );\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"640\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss on epoch 0: 0.803777813911438\n",
      "loss on epoch 1: 0.8190177083015442\n",
      "loss on epoch 2: 0.858698308467865\n",
      "loss on epoch 3: 0.8473163843154907\n",
      "loss on epoch 4: 0.808709442615509\n",
      "loss on epoch 5: 0.8570806384086609\n",
      "loss on epoch 6: 0.8125276565551758\n",
      "loss on epoch 7: 0.8104761242866516\n",
      "loss on epoch 8: 0.8296130895614624\n",
      "loss on epoch 9: 0.813572108745575\n",
      "loss on epoch 10: 0.8295480012893677\n",
      "loss on epoch 11: 0.8748180866241455\n",
      "loss on epoch 12: 0.8674363493919373\n",
      "loss on epoch 13: 0.8595669865608215\n",
      "loss on epoch 14: 0.8343032002449036\n",
      "loss on epoch 15: 0.7993026971817017\n",
      "loss on epoch 16: 0.8165237903594971\n",
      "loss on epoch 17: 0.8142260313034058\n",
      "loss on epoch 18: 0.8418599963188171\n",
      "loss on epoch 19: 0.8194049596786499\n",
      "loss on epoch 20: 0.8135650157928467\n",
      "loss on epoch 21: 0.8372604250907898\n",
      "loss on epoch 22: 0.8176058530807495\n",
      "loss on epoch 23: 0.8098580837249756\n",
      "loss on epoch 24: 0.8142467737197876\n",
      "loss on epoch 25: 0.892632007598877\n"
     ]
    }
   ],
   "source": [
    "# set up for plotting progress during training\n",
    "test_batch_ids = [2, 5, 10]  # pick canonical test render views\n",
    "num_subplots = len(test_batch_ids)\n",
    "fig, ax = plt.subplots(ncols=num_subplots)\n",
    "\n",
    "# run fitting loop\n",
    "image_list = []\n",
    "for epoch in range(num_epoch):\n",
    "    for idx, data in enumerate(dataloader):\n",
    "        optim.zero_grad()  # zero out gradients for new batch\n",
    "\n",
    "        # get image mask and camera extrinsics from the true 2D renders\n",
    "        gt_mask = data[\"semantic\"].cuda()\n",
    "        camera_transform = data[\"metadata\"][\"cam_transform\"].cuda()\n",
    "        camera_projection = data[\"metadata\"][\"cam_proj\"].cuda()\n",
    "\n",
    "        # project the 3D mesh (bbox) onto a canvas matching the ground truth shape\n",
    "        # and using the ground truth camera\n",
    "        img_shape = (gt_mask.shape[1], gt_mask.shape[2])\n",
    "        soft_mask = project_to_2d(\n",
    "            bbox,\n",
    "            batch_size=gt_mask.shape[0],  # actual batch size; the last batch may be smaller\n",
    "            image_shape=img_shape,\n",
    "            camera_transform=camera_transform,\n",
    "            camera_projection=camera_projection,\n",
    "        )\n",
    "\n",
    "        # compute the loss from two terms measured on the projected soft mask:\n",
    "        #   mask_overlap = term based on overlap with the ground-truth silhouette\n",
    "        mask_overlap = overlap(soft_mask, gt_mask.squeeze(-1))\n",
    "        #   mask_occupancy = term based on the area the projected mask covers\n",
    "        mask_occupancy = occupancy(soft_mask)\n",
    "\n",
    "        # weight the two terms with their hyperparameters\n",
    "        loss = mask_occupancy * mask_occupancy_hyper + mask_overlap * mask_overlap_hyper\n",
    "\n",
    "        # propagate losses\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "\n",
    "        # view training progress; only render 1 in 10 batches to reduce animation processing time\n",
    "        if idx % 10 == 0:\n",
    "            test_viz = [train_data[test_id] for test_id in test_batch_ids]\n",
    "            image_list.append(show_renders(bbox, test_viz))\n",
    "            for i in range(num_subplots):\n",
    "                ax[i].clear()\n",
    "                ax[i].imshow(image_list[-1][i])\n",
    "            fig.canvas.draw()\n",
    "\n",
    "    # step the learning rate schedule once per epoch; this reduces the size of\n",
    "    # changes to the bbox parameters to fine-tune in later stages of fitting\n",
    "    scheduler.step()\n",
    "    # report the loss of the last batch in this epoch\n",
    "    print(f\"loss on epoch {epoch}: {loss.item()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2d19daf",
   "metadata": {},
   "source": [
    "## Animation\n",
    "\n",
    "The code below compiles the renders collected during fitting into an animation that can be viewed later (e.g., when debugging the fit after the fact)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4fbc90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate an animation of the training loop\n",
    "num_subplots = len(test_batch_ids)\n",
    "fig, ax = plt.subplots(ncols=num_subplots)\n",
    "ims = []\n",
    "for i, frame in enumerate(image_list):\n",
    "    # one image artist per subplot for this animation frame\n",
    "    sp_ims = [ax[j].imshow(frame[j], animated=True) for j in range(num_subplots)]\n",
    "    # draw the first frame statically so the animation does not blink at the start\n",
    "    if i == 0:\n",
    "        for j in range(num_subplots):\n",
    "            ax[j].imshow(frame[j])\n",
    "    ims.append(sp_ims)\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=60, blit=True, repeat_delay=100)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d3be94a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ani.save(\"animation.gif\", writer=animation.PillowWriter())  # optionally save the animation"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
